import torch
# Demo: look up vectors for a batch of integer ids with torch.nn.Embedding.

# Table size: ids must lie in [0, num_embeddings); each id maps to a vector
# of length embedding_dim.
num_embeddings = 14
embedding_dim = 3

# Two sequences of eight ids each -> shape (2, 8). The largest id used is 13,
# which fits inside a table of num_embeddings rows.
# NOTE(review): the trailing 0s in the first row look like padding, but no
# padding_idx is passed below, so id 0 gets an ordinary trainable vector —
# confirm this is intended.
input_indices = torch.tensor(
    [
        [3, 2, 6, 5, 9, 12, 0, 0],
        [11, 2, 4, 8, 13, 7, 10, 12],
    ]
)
print(input_indices.shape)

# Lookup table of shape (num_embeddings, embedding_dim). Its weights are
# randomly initialised and no seed is set here, so printed values vary
# between runs.
emb = torch.nn.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
)
print(emb.weight)

# Each id is replaced by its row of emb.weight -> shape (2, 8, embedding_dim).
out_vector = emb(input_indices)
print(out_vector)